import pickle
import tempfile
import os
import time

import numpy as np
import torch
from cpprb import ReplayBuffer, create_before_add_func, create_env_dict
from kmeans_pytorch import kmeans
from sklearn.cluster import KMeans, AffinityPropagation
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sacred import Ingredient
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.optimize import linear_sum_assignment
from collections import defaultdict, Counter

from model_Het import ConditionalLinearVAE, BiHierarchicalVAE

ops_ingredient = Ingredient("ops")

@ops_ingredient.config
def config():
    # Default settings of the "ops" ingredient. Every variable assigned here is a
    # config entry; functions/methods decorated with @ops_ingredient.capture in
    # this module receive them by parameter name (e.g. rbDataSet.__init__ takes
    # encoder_in / decoder_in / reconstruct / policy_mask / other_1 / other_2).

    # -----------------------------for normal measuring-----------------------------

    delay_training_seps = True
    delay_steps = 2000 # delay steps for both training and measuring, mainly for SePS
    pretraining_times_seps = 1 


    reparameter_interval = 20 # for simplicity, the reparameterization interval is the same as the measure interval
    max_rb_steps = 20  # should not be larger than the reparameterization interval (presumably reparameter_interval)
    reparameter_times = 200000 # large value so that reparameterization effectively keeps running

    # Replay-buffer field(s) forming dataset column 3 of rbDataSet; compute_fusions
    # asserts its width equals the encoder-input width.
    policy_mask = ["act_mask"]
    model_count = 6
    initial_as_the_same = False # for all the policy models
    

    # Conditional-VAE training settings used by compute_fusions; batch_size is also
    # the DataLoader batch size in compute_implicit_het.
    batch_size = 128
    lr = 5e-4  # Adam learning rate
    epochs = 10
    z_features = 40  # first ConditionalLinearVAE constructor argument (presumably latent size)
    kl_weight = 0.01  # multiplier of the KL term in the VAE loss

    # for Policy Distance Computing
    # Replay-buffer fields concatenated (along dim 1) into each rbDataSet column.
    encoder_in = ["act_probs"]
    # NOTE(review): not captured by any function in this chunk; compute_fusions
    # conditions the VAE on the decoder_in column instead.
    encoder_condition = ["origin_obs"]
    decoder_in = ["origin_obs"] # and actually a sampled "z"
    # reconstruct = ["act_probs"]
    reconstruct = ["act_probs_with_mask"]
    other_1 = ["agent"]
    other_2 = ["local_state", "act"]


    # -----------------------------for Meta-Het Computing-----------------------------
    # Replay-buffer fields forming the four rbDataSet_IHet columns. compute_implicit_het
    # splits the encoder_condition_IHet column into local_state followed by act.
    encoder_in_IHet = ["agent"]
    encoder_condition_IHet = ["local_state", "act"]
    decoder_in_IHet = ["local_state", "act"]
    reconstruct_IHet = ["next_local_state", "next_obs", "rew_pad"]

    # Training settings of the implicit-heterogeneity VAE (compute_implicit_het).
    lr_IHet = 5e-4
    epochs_IHet = 10
    z_features_IHet = 40
    kl_weight_IHet = 0.0001




class rbDataSet(Dataset):
    """Replay-buffer fields grouped into the six tensors consumed by the MAPD VAE.

    ``self.data`` holds, in order, the dim-1 concatenation of the replay-buffer
    fields named by the ``encoder_in``, ``decoder_in``, ``reconstruct``,
    ``policy_mask``, ``other_1`` and ``other_2`` config entries. Each item is a
    list with one row taken from every one of these tensors.
    """

    @ops_ingredient.capture
    def __init__(self, rb, encoder_in, decoder_in, reconstruct, policy_mask, other_1, other_2):
        self.rb = rb
        field_groups = (encoder_in, decoder_in, reconstruct, policy_mask, other_1, other_2)
        self.data = [
            torch.cat([torch.from_numpy(rb[field]) for field in group], dim=1)
            for group in field_groups
        ]
        print([column.shape for column in self.data])

    def __len__(self):
        # number of rows in the first column
        return len(self.data[0])

    def __getitem__(self, idx):
        return [column[idx, :] for column in self.data]
    

class rbDataSet_IHet(Dataset):
    """Replay-buffer fields grouped into the four tensors consumed by the implicit-het VAE.

    ``self.data`` holds, in order, the dim-1 concatenation of the replay-buffer
    fields named by ``encoder_in_IHet``, ``encoder_condition_IHet``,
    ``decoder_in_IHet`` and ``reconstruct_IHet``. Each item is a list with one
    row taken from every one of these tensors.
    """

    @ops_ingredient.capture
    def __init__(self, rb, encoder_in_IHet, encoder_condition_IHet, decoder_in_IHet, reconstruct_IHet):
        self.rb = rb
        field_groups = (encoder_in_IHet, encoder_condition_IHet, decoder_in_IHet, reconstruct_IHet)
        self.data = [
            torch.cat([torch.from_numpy(rb[field]) for field in group], dim=1)
            for group in field_groups
        ]
        print([column.shape for column in self.data])

    def __len__(self):
        # number of rows in the first column
        return len(self.data[0])

    def __getitem__(self, idx):
        return [column[idx, :] for column in self.data]





@ops_ingredient.capture
def compute_fusions(rb, agent_count, policy_model, batch_size, lr, epochs, z_features, kl_weight, measure_mode=True, policy_mode='Pure', policy_submodel=None):
    '''
    Measure pairwise policy distances between agents (MAPD) with a conditional VAE.

    This is a model-based method, so it needs ``policy_model`` and trains an auxiliary
    conditional VAE on ``rb``. If the policy model has an obs+content concatenated input format:
    1. VAE training input is original obs
    2. Overall Policy input is original obs + content
    3. FPS+id input is original obs + id

    Steps:
        1. Train a ``ConditionalLinearVAE`` that encodes dataset column 0 conditioned on
           column 1 and reconstructs column 2 (see ``rbDataSet``).
        2. For every batch, query each agent's action probabilities, encode them with the
           VAE and compute Bhattacharyya / Hellinger / Wasserstein distances between the
           agents' latent Gaussians.

    Args:
        rb: replay buffer indexed by field name (wrapped in ``rbDataSet``).
        agent_count: number of agents N.
        policy_model: policy exposing ``parameters()``, ``get_act_probs``, ``laac_shallow``
            and ``laac_deep``.
        batch_size, lr, epochs, z_features, kl_weight: VAE training settings (from config).
        measure_mode: must be True.
        policy_mode: 'Pure' (every agent sees the raw obs) or 'WithID' (obs + one-hot id).
        policy_submodel: must be None in both supported modes.

    Returns:
        (BD, Hellinger, WD, policy_model.laac_shallow, policy_model.laac_deep); the first
        three are [N, N] tensors averaged over the batches of the replay buffer.

    Raises:
        ValueError: if the replay buffer contains no samples.
    '''
    assert measure_mode
    assert policy_mode in ['Pure', 'WithID'], "policy_mode must be in ['Pure', 'WithID']"
    # Both supported modes build policy inputs directly; fail before any training work.
    assert policy_submodel is None, f"policy_submodel must be None when policy_mode is '{policy_mode}'"

    # ---------------------------1. train Conditional VAE------------------------------
    print("Starting to train Conditional VAE for MAPD")
    device = next(policy_model.parameters()).device

    dataset = rbDataSet(rb)
    if len(dataset) == 0:
        raise ValueError("compute_fusions needs a non-empty replay buffer")

    encoder_input_size = dataset.data[0].shape[-1]
    encoder_condition_size = dataset.data[1].shape[-1]
    reconstruct_size = dataset.data[2].shape[-1]
    # The policy mask (column 3) must line up with the encoder input (column 0).
    assert encoder_input_size == dataset.data[3].shape[-1]

    VAE_model = ConditionalLinearVAE(z_features, encoder_input_size, encoder_condition_size, reconstruct_size)
    print(VAE_model)
    VAE_model.to(device)
    optimizer = torch.optim.Adam(VAE_model.parameters(), lr=lr)

    criterion = nn.MSELoss(reduction="sum")

    def final_loss(recon_loss, mu, logvar):
        """Return reconstruction loss + kl_weight * KL(q(z|x, c) || N(0, I)).

        KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        """
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kl_weight * KLD

    def fit(model, dataloader):
        """Run one training epoch; return the summed loss divided by the dataset size."""
        model.train()
        running_loss = 0.0
        for encoder_in, encoder_condition, y, *_ in dataloader:
            encoder_in, encoder_condition, y = encoder_in.to(device), encoder_condition.to(device), y.to(device)
            optimizer.zero_grad()
            reconstruction, mu, logvar = model(encoder_in, encoder_condition)
            loss = final_loss(criterion(reconstruction, y), mu, logvar)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()
        return running_loss / len(dataloader.dataset)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    previous_time = time.time()
    train_epoch_loss = float("nan")  # reported as nan when epochs == 0
    for _ in tqdm(range(epochs)):
        train_epoch_loss = fit(VAE_model, dataloader)
    current_time = time.time()
    print(f"Train Loss: {train_epoch_loss:.6f}")
    print("Training time:", current_time - previous_time)

    # ------------------------------2. Use VAE and agents' models to calculate dij------------------------------
    print("--------------Starting to calculate MAPD distance matrix--------------")
    # Inference only: disable training-mode behaviour (e.g. dropout) inside the VAE.
    VAE_model.eval()

    N = agent_count                          # agents_count
    E = policy_model.laac_shallow.shape[0]   # envs_count, this dimension is forced to be used due to Policy format
    batch_count = len(dataloader)

    BD_storage = None
    Hellinger_storage = None
    WD_storage = None

    previous_time = time.time()

    for _, policy_input, _, _, _, _ in dataloader:
        # A. get policy inputs
        B = policy_input.shape[0]
        hidden_sample_batch = policy_input.to(device)  # [B,H]

        if policy_mode == 'Pure':
            hidden_sample_list = [hidden_sample_batch.clone().unsqueeze(1).expand(-1, E, -1) for _ in range(N)]  # (N)[B,E,H]
        else:  # 'WithID'
            origin_obs = hidden_sample_batch.unsqueeze(0).expand(N, -1, -1)  # [N,B,H]
            agent_ids = torch.eye(N, device=device).unsqueeze(1).expand(-1, B, -1)  # [N,B,N]
            agent_inputs = torch.cat((origin_obs, agent_ids), dim=-1)  # [N,B,H+N]
            hidden_sample_list = [agent_inputs[k].unsqueeze(1).expand(-1, E, -1) for k in range(N)]  # (N)[B,E,H+N]

        # B. get policy outputs
        with torch.no_grad():
            act_probs = policy_model.get_act_probs(hidden_sample_list)  # (N)[B,E,A]

        VAE_input_batch = torch.stack(act_probs, dim=0)[:, :, 0, :]  # [N,B,A]
        if policy_mode == 'Pure':
            VAE_condition_batch = torch.stack(hidden_sample_list, dim=0)[:, :, 0, :]  # [N,B,H]
        else:
            VAE_condition_batch = origin_obs  # [N,B,H]

        # C. compute Dij (only one env is enough)
        with torch.no_grad():
            _, mus, sigmas = VAE_model.encode(VAE_input_batch, VAE_condition_batch)  # [N,B,D]

        if BD_storage is None:
            n_agents = mus.shape[0]
            BD_storage = torch.zeros([n_agents, n_agents]).to(device)
            Hellinger_storage = torch.zeros([n_agents, n_agents]).to(device)
            WD_storage = torch.zeros([n_agents, n_agents]).to(device)

        BD = calculate_N_Gaussians_BD(mus.transpose(0, 1), sigmas.transpose(0, 1))  # [B,N,D] --> [N,N]
        Hellinger = calculate_N_Gaussians_Hellinger_through_BD(BD)
        WD = calculate_N_Gaussians_WD(mus.transpose(0, 1), sigmas.transpose(0, 1))

        # NOTE(review): every batch gets equal weight, including a smaller final batch;
        # compute_implicit_het weights by batch size instead — confirm which is intended.
        BD_storage += BD
        Hellinger_storage += Hellinger
        WD_storage += WD

    print("Time needed for distance matrix calculation:", time.time() - previous_time)

    BD, Hellinger, WD = BD_storage / batch_count, Hellinger_storage / batch_count, WD_storage / batch_count

    return BD, Hellinger, WD, policy_model.laac_shallow, policy_model.laac_deep


@ops_ingredient.capture
def compute_implicit_het(rb, agent_count, policy_model, batch_size, lr_IHet, epochs_IHet, z_features_IHet, kl_weight_IHet, _log, 
                         continue_IHet_training=False, 
                         pretrained_VAE=None, 
                         compute_IHet=False,
                         debug_mode=True,
                         het_compute_mode="fixed_obs_all_acts"):
    """
    Train (or continue training) the implicit-heterogeneity VAE and optionally use it to
    compute distances between agents.

    Args:
        rb: replay buffer indexed by field name (wrapped in ``rbDataSet_IHet``).
        agent_count: number of agents N.
        policy_model: only used to determine the device.
        batch_size, lr_IHet, epochs_IHet, z_features_IHet, kl_weight_IHet: training settings.
        _log: sacred logger (unused here).
        continue_IHet_training (bool): Whether to continue training an existing VAE model.
        pretrained_VAE: Pre-trained VAE model, only used when continue_IHet_training=True.
        compute_IHet (bool): Whether to compute implicit heterogeneity distances.
        debug_mode (bool): Whether to print training/distance timings and the final loss.
        het_compute_mode (str): Method for computing implicit heterogeneity
            - "full_combination": Fully separate obs and act, compute all possible combinations
            - "fixed_obs_all_acts": Fix sampled obs, iterate through all possible discrete actions for each obs
            Any other value is treated as "fixed_obs_all_acts".

    Returns:
        (BD, Hellinger, WD, VAE_model), with the distances as [N, N] tensors averaged over
        all (obs, act) conditions, or (None, None, None, VAE_model) when compute_IHet is False.
    """
    device = next(policy_model.parameters()).device

    # ---------------------------1. train Implicit Het VAE------------------------------
    print("Starting to train implicit heterogeneity VAE")
    dataset = rbDataSet_IHet(rb)

    encoder_input_size = dataset.data[0].shape[-1]
    encoder_condition_size = dataset.data[1].shape[-1]
    reconstruct_size = dataset.data[3].shape[-1]

    if continue_IHet_training and pretrained_VAE is not None:
        VAE_model = pretrained_VAE
        print("Using pre-trained VAE model to continue training")
    else:
        VAE_model = BiHierarchicalVAE(z_features_IHet, encoder_input_size, encoder_condition_size, reconstruct_size)
        print("Creating new VAE model")
    print("Implicit Het VAE_model is:")
    print(VAE_model)
    VAE_model.to(device)
    # A reused model may have been left in eval mode by an earlier distance computation.
    VAE_model.train()
    optimizer = torch.optim.Adam(VAE_model.parameters(), lr=lr_IHet)

    criterion = nn.MSELoss(reduction="sum")

    def final_loss(recon_loss, mu, logvar):
        """Return reconstruction loss + kl_weight_IHet * KL(q(z|x, c) || N(0, I)).

        KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        """
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kl_weight_IHet * KLD

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    if debug_mode:
        previous_time = time.time()

    # Defined up front so the debug report also works when epochs_IHet == 0.
    running_loss = 0.0
    total_samples = 0
    for _ in tqdm(range(epochs_IHet)):
        # Statistics are reset every epoch, so the report covers the last epoch only.
        running_loss = 0.0
        total_samples = 0

        for encoder_in, encoder_condition, decoder_in, y in dataloader:
            optimizer.zero_grad()
            reconstruction, mu, logvar = VAE_model(
                encoder_in.to(device),
                encoder_condition.to(device),
                decoder_in.to(device)
            )
            loss = final_loss(criterion(reconstruction, y.to(device)), mu, logvar)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            total_samples += encoder_in.size(0)

    if debug_mode:
        print("VAE Training time:", time.time() - previous_time)
        if total_samples > 0:
            print(f"Final Average Loss: {running_loss/total_samples:.6f}")
        else:
            print("Final Average Loss: n/a (no training samples)")

    if not compute_IHet:
        return None, None, None, VAE_model

    # ---------------------------2. Use VAE to calculate distances------------------------------
    print("--------------Starting to calculate implicit heterogeneity distance--------------")
    print(f"Using calculation mode: {het_compute_mode}")
    previous_time = time.time() if debug_mode else None
    # Inference only: disable training-mode behaviour (e.g. dropout) inside the VAE.
    VAE_model.eval()

    # Get dimensions of local_state and act.
    # assumes dataset column 1 (encoder_condition_IHet) is local_state followed by act — TODO confirm
    local_state_dim = rb["local_state"].shape[1]
    act_dim = rb["act"].shape[1]

    # Get all sampled local_state
    all_local_states = torch.cat([batch[1][:, :local_state_dim].to(device) for batch in dataloader], dim=0)

    # Build conditions based on the calculation mode
    if het_compute_mode == "full_combination":
        print("Using method 1: Integrate all possible (obs,act) combinations")
        all_acts = torch.cat([batch[1][:, local_state_dim:].to(device) for batch in dataloader], dim=0)

        # Randomly subsample to bound the number of combinations
        if len(all_local_states) * len(all_acts) > 1000000:
            sample_size_ls = min(1000, len(all_local_states))
            sample_size_act = min(1000, len(all_acts))
            all_local_states = all_local_states[torch.randperm(len(all_local_states))[:sample_size_ls]]
            all_acts = all_acts[torch.randperm(len(all_acts))[:sample_size_act]]

        o_expanded = all_local_states.unsqueeze(1).expand(-1, len(all_acts), -1)
        a_expanded = all_acts.unsqueeze(0).expand(len(all_local_states), -1, -1)
    else:  # "fixed_obs_all_acts"
        print("Using method 2: Fix obs, iterate through all possible discrete actions for each obs")
        all_acts = torch.eye(act_dim).to(device)  # every one-hot action
        o_expanded = all_local_states.unsqueeze(1).expand(-1, act_dim, -1)
        a_expanded = all_acts.unsqueeze(0).expand(len(all_local_states), -1, -1)
    conditions = torch.cat([o_expanded, a_expanded], dim=2).reshape(-1, local_state_dim + act_dim)

    # Initialize distance matrix storage
    N = agent_count
    BD_storage = torch.zeros([N, N]).to(device)
    Hellinger_storage = torch.zeros([N, N]).to(device)
    WD_storage = torch.zeros([N, N]).to(device)

    # Compute in chunks to bound memory use
    compute_batch_size = 10000
    num_conditions = len(conditions)
    agent_eye = torch.eye(N).to(device)  # one-hot agent ids fed as encoder input

    for start_idx in tqdm(range(0, num_conditions, compute_batch_size)):
        condition_batch = conditions[start_idx:start_idx + compute_batch_size]
        chunk_len = condition_batch.size(0)

        agent_ids = agent_eye.unsqueeze(1).expand(-1, chunk_len, -1)          # [N,C,N]
        condition_expanded = condition_batch.unsqueeze(0).expand(N, -1, -1)   # [N,C,ls+act]

        with torch.no_grad():
            _, mus, sigmas = VAE_model.encode(agent_ids, condition_expanded)

        BD = calculate_N_Gaussians_BD(mus.transpose(0, 1), sigmas.transpose(0, 1))
        Hellinger = calculate_N_Gaussians_Hellinger_through_BD(BD)
        WD = calculate_N_Gaussians_WD(mus.transpose(0, 1), sigmas.transpose(0, 1))

        # Weight each chunk by its size so the result is a per-condition average
        BD_storage += BD * chunk_len
        Hellinger_storage += Hellinger * chunk_len
        WD_storage += WD * chunk_len

    BD = BD_storage / num_conditions
    Hellinger = Hellinger_storage / num_conditions
    WD = WD_storage / num_conditions

    if debug_mode:
        print("Time needed for distance calculation:", time.time() - previous_time)

    return BD, Hellinger, WD, VAE_model



@ops_ingredient.capture
def compute_dynamic_parameter_sharing(WD_matrix, previous_clustering=None, previous_network_assignments=None, 
                                     policy_model=None, merge_mode='majority', _log=None, seed=None):
    """
    Dynamic parameter sharing algorithm based on heterogeneous distances.

    Agents are clustered with AffinityPropagation on the negated distance matrix. The new
    clusters are then matched against the previous clustering so agents keep using (copies
    of) the networks they used before:

    * Same number of clusters: one-to-one Hungarian matching; each new cluster inherits the
      majority network of its matched old cluster.
    * More new clusters (split): matched clusters keep their old network; each unmatched
      cluster gets a fresh id (starting at max(previous ids) + 1) initialised by copying
      the most similar old network.
    * Fewer new clusters (merge): every old cluster is attached to one new cluster (Hungarian
      matches first, then each remaining old cluster joins the new cluster it overlaps most),
      and the networks attached to a new cluster are merged per ``merge_mode``. New network
      ids are 0..K-1 (K = number of new clusters).

    Args:
        WD_matrix: Wasserstein distance matrix between agents [N,N] (torch tensor)
        previous_clustering: Previous clustering results - cluster labels for each agent [N]
        previous_network_assignments: Previous network assignments - network IDs used by each agent [N]
        policy_model: Agent policy network model, used for parameter merging and splitting operations
            (``copy_parameters(src, dst)`` and ``merge_parameters(srcs, target=..., merge_mode=...)``)
        merge_mode: Network merging logic, options:
                   'majority' - Use the strategy with the most agents
                   'random' - Randomly select a strategy
                   'average' - Use averaging method
                   'weighted' - Use weighted averaging method
                   Any other value keeps the majority network's parameters.
        _log: sacred logger (unused)
        seed: random_state passed to AffinityPropagation

    Returns:
        new_cluster_assignments: New clustering information - cluster labels for each agent [N]
        network_assignments: Network indices that each agent should use [N]
    """
    N = WD_matrix.shape[0]  # Number of agents

    # AffinityPropagation works on similarities, hence the negated distances.
    # NOTE(review): sklearn labels every sample -1 when AffinityPropagation does not converge;
    # such labels would propagate into the network ids below.
    affinity_matrix = -WD_matrix.cpu().numpy()
    print("Using AffinityPropagation for agent clustering...")
    ap = AffinityPropagation(affinity='precomputed', random_state=seed)
    cluster_new = ap.fit_predict(affinity_matrix)

    unique_new_clusters = np.unique(cluster_new)
    num_new_clusters = len(unique_new_clusters)
    print(f"Clustering result: {num_new_clusters} clusters")

    # If first time clustering
    if previous_clustering is None or previous_network_assignments is None:
        print("First clustering, using cluster labels directly as network assignments...")
        return cluster_new, cluster_new.copy()

    # Accept any array-like; the masks and fancy indexing below need numpy arrays.
    cluster_old = np.asarray(previous_clustering)
    network_old = np.asarray(previous_network_assignments)

    unique_old_clusters = np.unique(cluster_old)
    num_old_clusters = len(unique_old_clusters)
    print(f"Previous clustering: {num_old_clusters} clusters")
    print(f"New clustering: {num_new_clusters} clusters")

    # overlap_matrix[i, j] = number of agents in new cluster i that were in old cluster j
    new_pos = np.searchsorted(unique_new_clusters, cluster_new)
    old_pos = np.searchsorted(unique_old_clusters, cluster_old)
    overlap_matrix = np.zeros((num_new_clusters, num_old_clusters))
    np.add.at(overlap_matrix, (new_pos, old_pos), 1)

    # Network used by most agents of each old cluster (ties: first seen)
    old_cluster_to_network = {
        old_cluster: Counter(network_old[cluster_old == old_cluster]).most_common(1)[0][0]
        for old_cluster in unique_old_clusters
    }

    network_new = np.zeros_like(cluster_new)

    if num_old_clusters == num_new_clusters:
        print("Case 1: Old clusters equal new clusters, establishing one-to-one mapping...")

        row_ind, col_ind = linear_sum_assignment(-overlap_matrix)
        new_to_old_mapping = {unique_new_clusters[i]: unique_old_clusters[j] for i, j in zip(row_ind, col_ind)}

        for i in range(N):
            new_cluster = cluster_new[i]
            old_cluster = new_to_old_mapping.get(new_cluster, new_cluster)
            network_new[i] = old_cluster_to_network.get(old_cluster, new_cluster)

    elif num_old_clusters < num_new_clusters:
        print("Case 2: Old clusters less than new clusters, handling network splitting...")

        # With more new clusters than old ones, some new clusters remain unmatched.
        row_ind, col_ind = linear_sum_assignment(-overlap_matrix)
        new_to_old_mapping = {}
        for i, j in zip(row_ind, col_ind):
            # Only establish a mapping when the clusters actually share agents
            if overlap_matrix[i, j] > 0:
                new_to_old_mapping[unique_new_clusters[i]] = unique_old_clusters[j]

        new_to_network_mapping = {}
        next_network_id = max(network_old) + 1 if len(network_old) > 0 else 0

        for new_idx, new_cluster in enumerate(unique_new_clusters):
            if new_cluster in new_to_old_mapping:
                # Matched new clusters use the corresponding old network
                new_to_network_mapping[new_cluster] = old_cluster_to_network[new_to_old_mapping[new_cluster]]
                continue

            # Unmatched new cluster: fresh network initialised from the most similar old one
            new_to_network_mapping[new_cluster] = next_network_id
            if np.max(overlap_matrix[new_idx]) > 0:
                best_old_cluster = unique_old_clusters[np.argmax(overlap_matrix[new_idx])]
                source_network = old_cluster_to_network[best_old_cluster]
                if policy_model is not None:
                    print(f"Split operation: Copy from network {source_network} to new network {next_network_id}")
                    policy_model.copy_parameters(source_network, next_network_id)
            else:
                # No overlap: pick the old cluster closest (by WD) to one agent of this cluster
                agents_in_new = np.where(cluster_new == new_cluster)[0]
                if len(agents_in_new) > 0:
                    center_agent = agents_in_new[0]
                    min_dist = float('inf')
                    closest_network = None
                    for old_cluster in unique_old_clusters:
                        agents_in_old = np.where(cluster_old == old_cluster)[0]
                        if len(agents_in_old) > 0:
                            dist = float(WD_matrix[center_agent, agents_in_old[0]].cpu().numpy())
                            if dist < min_dist:
                                min_dist = dist
                                closest_network = old_cluster_to_network[old_cluster]

                    if policy_model is not None and closest_network is not None:
                        print(f"Split operation: Copy from network {closest_network} to new network {next_network_id}")
                        policy_model.copy_parameters(closest_network, next_network_id)

            next_network_id += 1

        for i in range(N):
            network_new[i] = new_to_network_mapping[cluster_new[i]]

    else:  # num_old_clusters > num_new_clusters
        print("Case 3: Old clusters more than new clusters, handling network merging...")

        # Hungarian matching on the (old x new) matrix gives each new cluster exactly ONE old
        # cluster. Keep those matches, then attach every remaining old cluster to the new
        # cluster it overlaps most; otherwise no new cluster could receive more than one
        # network and nothing would ever be merged.
        row_ind, col_ind = linear_sum_assignment(-overlap_matrix.T)
        old_to_new_mapping = {}
        for old_idx, new_idx in zip(row_ind, col_ind):
            if overlap_matrix[new_idx, old_idx] > 0:
                old_to_new_mapping[unique_old_clusters[old_idx]] = unique_new_clusters[new_idx]
        for old_idx, old_cluster in enumerate(unique_old_clusters):
            if old_cluster not in old_to_new_mapping and overlap_matrix[:, old_idx].max() > 0:
                old_to_new_mapping[old_cluster] = unique_new_clusters[np.argmax(overlap_matrix[:, old_idx])]

        # Source networks of each new cluster, and the one whose parameters it keeps.
        sources_of = {}
        kept_network = {}
        for new_cluster in unique_new_clusters:
            merged_old = [oc for oc, nc in old_to_new_mapping.items() if nc == new_cluster]
            sources = list(dict.fromkeys(old_cluster_to_network[oc] for oc in merged_old))
            sources_of[new_cluster] = sources
            if not sources:
                kept_network[new_cluster] = None
            elif merge_mode == 'random' and len(sources) > 1:
                kept_network[new_cluster] = np.random.choice(sources)
            else:
                # Network used by the most agents (first one on ties)
                kept_network[new_cluster] = max(sources, key=lambda net: np.sum(network_old == net))

        # New ids are 0..K-1. A new cluster whose kept network already has an id in that range
        # reuses it; writing id i for "new cluster i" could otherwise overwrite an old network
        # that another new cluster still has to read (e.g. copy 2->0 and then 0->1).
        K = num_new_clusters
        new_to_network_mapping = {}
        claimed = set()
        for new_cluster in unique_new_clusters:
            kept = kept_network[new_cluster]
            if kept is not None and 0 <= kept < K and kept not in claimed:
                new_to_network_mapping[new_cluster] = kept
                claimed.add(kept)
        free_ids = (net_id for net_id in range(K) if net_id not in claimed)
        for new_cluster in unique_new_clusters:
            if new_cluster not in new_to_network_mapping:
                new_to_network_mapping[new_cluster] = next(free_ids)

        if policy_model is not None:
            # Clusters that keep their own id go first, so their merges read every source
            # before any copy into a free id can overwrite it.
            # NOTE(review): a free id that is also a non-kept source of another merging
            # cluster could still be overwritten first in rare layouts.
            ordered = sorted(unique_new_clusters,
                             key=lambda c: new_to_network_mapping[c] != kept_network[c])
            for new_cluster in ordered:
                sources = sources_of[new_cluster]
                if not sources:
                    continue
                kept = kept_network[new_cluster]
                target = new_to_network_mapping[new_cluster]
                if len(sources) > 1 and merge_mode in ('average', 'weighted'):
                    print(f"Merge operation ({merge_mode}): Merge networks {sources} to network {target}")
                    policy_model.merge_parameters(sources, target=target, merge_mode=merge_mode)
                elif kept != target:
                    print(f"Copy parameters: From network {kept} to network {target}")
                    policy_model.copy_parameters(kept, target)
                elif len(sources) > 1:
                    print(f"Merge operation: Networks {sources} merged by keeping network {kept}")

        for i in range(N):
            network_new[i] = new_to_network_mapping[cluster_new[i]]

    print(f"Original clustering: {cluster_new}")
    print(f"Network assignment: {network_new}")

    return cluster_new, network_new

# Visualize clustering results
@ops_ingredient.capture
def visualize_clusters(WD_matrix, cluster_assignments, cluster_centers, _run):
    """Plot agents in 2-D, coloured by cluster, and store the figure as a run artifact.

    Positions come from a metric MDS embedding of the precomputed distance matrix.

    Args:
        WD_matrix: [N, N] torch tensor of pairwise distances (moved to CPU here).
        cluster_assignments: cluster label of each agent (length N, any array-like).
        cluster_centers: agent indices to mark with a star.
        _run: sacred run (injected by capture); receives ``cluster_visualization.png``.
    """
    from sklearn.manifold import MDS

    N = WD_matrix.shape[0]
    # A plain list would make the ``== cluster_id`` mask below a single bool.
    cluster_assignments = np.asarray(cluster_assignments)

    # Use multidimensional scaling to transform distance matrix into 2D coordinates
    mds = MDS(n_components=2, dissimilarity='precomputed', random_state=42)
    positions = mds.fit_transform(WD_matrix.cpu().numpy())

    fig = plt.figure(figsize=(10, 8))
    try:
        palette = list(mcolors.TABLEAU_COLORS.values())

        # Cycle the palette so clusters beyond its length are still drawn.
        for k, cluster_id in enumerate(np.unique(cluster_assignments)):
            cluster_points = positions[cluster_assignments == cluster_id]
            plt.scatter(cluster_points[:, 0], cluster_points[:, 1],
                        c=palette[k % len(palette)], label=f'Cluster {cluster_id}')

        # Mark cluster centers (only the first one gets a legend entry)
        for i, center_id in enumerate(cluster_centers):
            plt.scatter(positions[center_id, 0], positions[center_id, 1],
                        c='black', marker='*', s=200, label=f'Center {center_id}' if i == 0 else "")

        # Add agent ID labels to each point
        for i in range(N):
            plt.annotate(str(i), (positions[i, 0], positions[i, 1]), fontsize=8)

        plt.title('Agent Clustering Results')
        plt.legend()

        # Close the temp file before writing by path so the PNG is fully flushed when the
        # observer reads it. The file is left on disk, as before, in case an observer
        # reads it after add_artifact returns.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            tmp_path = tmpfile.name
        fig.savefig(tmp_path, format="png")
        _run.add_artifact(tmp_path, "cluster_visualization.png")
    finally:
        plt.close(fig)

def _snd_build_policy_inputs(policy_input, other_2, N, E, device, policy_mode, policy_submodel):
    """Build the N per-agent policy inputs for one batch; each entry is a [B, E, H'] tensor.

    H' is H for 'Pure', H + N for 'WithID' (one-hot agent id appended) and H + Z for
    'WithIHet' (features from ``policy_submodel.encode`` appended). The mode is assumed
    to have been validated by the caller.
    """
    obs = policy_input.to(device)                        # [B,H]
    B = obs.shape[0]
    obs_per_agent = obs.unsqueeze(0).expand(N, -1, -1)   # [N,B,H]

    if policy_mode == 'Pure':
        agent_inputs = obs_per_agent
    elif policy_mode == 'WithID':
        agent_ids = torch.eye(N, device=device).unsqueeze(1).expand(-1, B, -1)  # [N,B,N]
        agent_inputs = torch.cat((obs_per_agent, agent_ids), dim=-1)            # [N,B,H+N]
    else:  # 'WithIHet'
        ihet_in = other_2.to(device).unsqueeze(0).expand(N, -1, -1)             # [N,B,X]
        agent_ids = torch.eye(N, device=device, dtype=torch.float32).unsqueeze(1).expand(-1, B, -1)
        ihet_features, _, _ = policy_submodel.encode(agent_ids, ihet_in)        # [N,B,Z]
        agent_inputs = torch.cat((obs_per_agent, ihet_features), dim=-1)        # [N,B,H+Z]

    # Every env slot receives the same input.
    return [agent_inputs[n].unsqueeze(1).expand(-1, E, -1) for n in range(N)]


def _snd_pairwise_categorical_BD(act_probs):
    """Batch-averaged Bhattacharyya distance between every pair of N categorical distributions.

    Args:
        act_probs: [N, B, A] probabilities (each row assumed to sum to 1).

    Returns:
        [N, N] tensor: mean over B of -log(sum_a sqrt(p_i * p_j)).
    """
    sqrt_p = act_probs.permute(1, 0, 2).sqrt()           # [B,N,A]
    bc = torch.bmm(sqrt_p, sqrt_p.transpose(1, 2))       # [B,N,N] Bhattacharyya coefficients
    # BC <= 1 for normalized distributions, but round-off can push it slightly above 1
    # (a negative BD); BC == 0 (disjoint supports) would make the BD infinite.
    bc = bc.clamp(min=torch.finfo(bc.dtype).tiny, max=1.0)
    return (-torch.log(bc)).mean(dim=0)


@ops_ingredient.capture
def compute_normal_SND(rb, agent_count, policy_model, batch_size, lr, epochs, z_features, kl_weight, _log, policy_mode='Pure', policy_submodel=None):
    """Estimate pairwise behavioural distances between the agents' policies.

    For every mini-batch drawn from the replay buffer, each agent's policy is evaluated on
    the same observations. The Bhattacharyya distance (BD) between the resulting action
    distributions is averaged over the batch, and a Hellinger distance is derived from that
    batch BD via ``calculate_N_Gaussians_Hellinger_through_BD``. Both are then averaged
    over batches.

    Args:
        rb: replay buffer, wrapped by ``rbDataSet``. The 2nd item of each sample is used as
            the policy input and the 6th (``other_2``) as the IHet encoder input.
        agent_count: number of agents N.
        policy_model: model exposing ``parameters()``, ``laac_shallow`` (its first dim is
            used as the env count E) and ``get_act_probs(list of N [B,E,H'] tensors)``.
        batch_size: DataLoader mini-batch size.
        lr, epochs, z_features, kl_weight, _log: unused here; kept for the captured signature.
        policy_mode: how per-agent inputs are built: 'Pure' (observation only), 'WithID'
            (observation + one-hot agent id) or 'WithIHet' (observation + features from
            ``policy_submodel.encode``).
        policy_submodel: encoder required in 'WithIHet' mode; must be None otherwise.

    Returns:
        (BD, Hellinger): two [N, N] tensors averaged over all batches.

    Raises:
        ValueError: unknown ``policy_mode``, inconsistent ``policy_submodel``, or a replay
            buffer that yields no batches.
    """
    if policy_mode not in ('Pure', 'WithID', 'WithIHet'):
        raise ValueError(f"Unknown policy_mode {policy_mode!r}; expected 'Pure', 'WithID' or 'WithIHet'")
    if policy_mode == 'WithIHet':
        if policy_submodel is None:
            raise ValueError("policy_submodel must be provided when policy_mode is 'WithIHet'")
    elif policy_submodel is not None:
        raise ValueError(f"policy_submodel must be None when policy_mode is '{policy_mode}'")

    device = next(policy_model.parameters()).device
    N = agent_count
    E = policy_model.laac_shallow.shape[0]  # envs count
    dataloader = DataLoader(rbDataSet(rb), batch_size=batch_size, shuffle=True)

    BD_sum = None
    Hellinger_sum = None
    batch_count = 0
    # Evaluation only: no autograd graph is needed for the policy or the IHet encoder.
    with torch.no_grad():
        for _, policy_input, _, _, _, other_2 in dataloader:
            hidden_sample_list = _snd_build_policy_inputs(
                policy_input, other_2, N, E, device, policy_mode, policy_submodel)
            act_probs = torch.stack(policy_model.get_act_probs(hidden_sample_list), dim=0)  # [N,B,E,A]

            # All env slots were fed identical inputs; only env 0's outputs are compared.
            # NOTE(review): if envs map to different networks, other envs may differ.
            BD = _snd_pairwise_categorical_BD(act_probs[:, :, 0, :])  # [N,N]
            Hellinger = calculate_N_Gaussians_Hellinger_through_BD(BD)

            if BD_sum is None:
                BD_sum = torch.zeros_like(BD)
                Hellinger_sum = torch.zeros_like(Hellinger)
            BD_sum += BD
            Hellinger_sum += Hellinger
            batch_count += 1

    if batch_count == 0:
        raise ValueError("compute_normal_SND: the replay buffer yielded no batches")
    return BD_sum / batch_count, Hellinger_sum / batch_count





def calculate_N_Gaussians_BD(mus, log_vars):
    """
    Calculate the Bhattacharyya distance between N diagonal multivariate Gaussians.

    Args:
        mus: [B, N, D] means; N is the number of distributions, D the dimension.
        log_vars: [B, N, D] log-variances of the diagonal covariances (same shape as mus).

    Returns:
        [N, N] tensor. The B*D coordinates of each distribution are treated as one
        diagonal Gaussian, and its BD is divided by B*D (i.e. a per-coordinate average).
    """
    import math

    assert mus.dim() == 3
    assert mus.shape == log_vars.shape
    B, N, D = mus.shape

    mus = mus.transpose(0, 1).reshape(N, B * D)             # [N, B*D]
    log_vars = log_vars.transpose(0, 1).reshape(N, B * D)   # [N, B*D]

    mus_i = mus.unsqueeze(1).expand(-1, N, -1)              # [N, N, B*D]
    mus_j = mus_i.transpose(0, 1)
    lv_i = log_vars.unsqueeze(1).expand(-1, N, -1)
    lv_j = lv_i.transpose(0, 1)

    # log of the mean variance, computed without exp() so huge or tiny log-variances
    # neither overflow nor underflow.
    log_mean_var = torch.logaddexp(lv_i, lv_j) - math.log(2.0)

    # 0.5 * ln( mean_var / sqrt(var_i * var_j) ), combined per element so that large
    # log-variances cancel before summation.
    log_det_term = 0.5 * (log_mean_var - 0.5 * (lv_i + lv_j))
    # 1/8 * d^2 / mean_var == 1/4 * d^2 / (var_i + var_j)
    mahalanobis_term = 0.25 * (mus_i - mus_j) ** 2 / (torch.exp(lv_i) + torch.exp(lv_j))

    return (log_det_term + mahalanobis_term).sum(dim=-1) / (B * D)


def calculate_N_Gaussians_Hellinger_through_BD(BD, max_value=15.0):
    """Convert Bhattacharyya distances to Hellinger distances element-wise.

    Uses BC = exp(-BD) and H = sqrt(1 - BC).

    Args:
        BD: tensor of Bhattacharyya distances (any shape).
        max_value: upper clip applied to BD; H saturates at sqrt(1 - exp(-max_value)).

    Returns:
        Tensor of the same shape as BD with values in [0, 1].
    """
    # BD is non-negative in theory; small negative values from round-off (e.g. on the
    # diagonal) would make 1 - exp(-BD) negative and the sqrt NaN, so clip at 0.
    BD_clipped = torch.clamp(BD, min=0.0, max=max_value)

    # -expm1(-x) == 1 - exp(-x), but stays accurate when x is close to 0.
    return torch.sqrt(-torch.expm1(-BD_clipped))


def calculate_N_Gaussians_WD(mus, log_vars):
    """
    Calculate the 2-Wasserstein distance between N diagonal multivariate Gaussians.

    For diagonal covariances, W2 = sqrt(||mu_i - mu_j||^2 + ||std_i - std_j||^2).

    Args:
        mus: [B, N, D] means; B is averaged over, N is the number of Gaussians, D the dimension.
        log_vars: [B, N, D] log-variances of the diagonal covariances (same shape as mus).

    Returns:
        [N, N] tensor: W2 distance for each batch element, averaged over B.
    """
    assert mus.dim() == 3
    assert mus.shape == log_vars.shape

    # exp(0.5 * log_var) is the standard deviation directly; exp(log_var) ** 0.5 would
    # overflow to inf (and give inf - inf = NaN) for much smaller log-variances.
    stds = torch.exp(0.5 * log_vars)                                   # [B,N,D]

    mean_term = ((mus.unsqueeze(2) - mus.unsqueeze(1)) ** 2).sum(dim=-1)   # [B,N,N]
    std_term = ((stds.unsqueeze(2) - stds.unsqueeze(1)) ** 2).sum(dim=-1)  # [B,N,N]

    return torch.sqrt(mean_term + std_term).mean(dim=0)               # [N,N]





